import matplotlib
import numpy as np
from matplotlib import pyplot as plt
import analytics.analytic_funcs as AF
import math

# Default number of histogram bins for the single-sample asymmetry / width histograms in this module.
defaultBinning = 100
# Colours for the multi-sample ("...Arr") overlay plots; samples are coloured in this order.
sixColors=['red', 'blue', 'darkorange', 'darkviolet', 'gold', 'midnightblue']

def plotMean(imgs):
    """Show the pixel-wise mean of an image stack, with a colorbar.

    :param imgs: stack of images; the mean is taken along axis 0 (one image per entry).
    """
    # mean image over the dataset that was passed in
    meanImage = np.mean(imgs, axis=0, keepdims=False)
    plt.imshow(meanImage)
    plt.colorbar()


def plotMeanWithTitle(imgs, title):
    """Draw the mean response image under a fixed figure title plus a per-axes ``title``."""
    figureTitle = 'ECAL: mean response of deposited energy'
    plt.suptitle(figureTitle, fontsize=16)
    plt.title(title)
    plotMean(imgs)


def plotMeanAbsDiff(ecal, fake):
    """Show |mean(ecal) - mean(fake)| pixel-wise, with a colorbar.

    Both stacks are averaged along axis 0 before the difference is taken.
    """
    meanReal = np.mean(ecal, axis=0, keepdims=False)
    meanFake = np.mean(fake, axis=0, keepdims=False)
    plt.imshow(abs(meanReal - meanFake))
    plt.colorbar()


def plotResponseImg(img, vmin=None, vmax=None, removeTicks=True, colorbar=True):
    """Draw one response image with nearest-neighbour interpolation.

    :param img: 2-D image to show.
    :param vmin: lower colour limit (None lets matplotlib choose).
    :param vmax: upper colour limit (None lets matplotlib choose).
    :param removeTicks: hide the x/y tick marks.
    :param colorbar: attach a colorbar to this image.
    """
    image = plt.imshow(img, interpolation='nearest', vmin=vmin, vmax=vmax)
    # TODO: what to do if we got 0 as one of inputs here?
    if colorbar:
        # NOTE(review): fraction/pad look tuned to keep the bar next to a square image — confirm
        plt.colorbar(image, fraction=0.046, pad=0.04)
    if not removeTicks:
        return
    plt.xticks([])
    plt.yticks([])

def plotResponses(ecalData, logScale=True, fakeData=None, suptitle=None, sheets=1):
    """Draw a 4x2 grid of individual responses that share one colour scale.

    Without ``fakeData`` the grid shows 8 consecutive Geant events, with the left
    column holding events 0-3 and the right column events 4-7 of the page. With
    ``fakeData`` it shows 4 Geant events in the left column and the GAN events with
    the same indices in the right column.

    :param ecalData: object whose ``response`` attribute is indexable along axis 0.
    :param logScale: plot log10 of the responses (zero deposits become -inf and are
        excluded from the colour limits).
    :param fakeData: optional object with a ``response`` attribute, shown on the right.
    :param suptitle: overrides the default figure title.
    :param sheets: 1-based page number; page k shows items [(k-1)*step, k*step), where
        step is 4 with ``fakeData`` and 8 without.
    """
    if suptitle is None:
        suptitle = 'Examples of ' + ('logScaled ' if logScale else '') + 'energy deposite \n generated by GEANT4' \
                   + ('' if fakeData is None else '(left) vs. generated by GAN(right)')
    plt.suptitle(suptitle, fontsize=16)

    step = 4 if fakeData is not None else 8
    start = (sheets - 1) * step
    finish = start + step

    combined = ecalData.response[start:finish]
    if fakeData is not None:
        combined = np.concatenate((combined, fakeData.response[start:finish]))
    if logScale:
        # log10(0) = -inf is expected for empty cells; it is masked out of the limits below
        with np.errstate(divide='ignore', invalid='ignore'):
            combined = np.log10(combined)

    # Mask -inf/nan for BOTH limits: np.amax on the raw array returns nan as soon as one
    # NaN is present, which would break the shared colour scale.
    finite = np.ma.masked_invalid(combined)
    if finite.count() == 0:
        vmin = vmax = None  # nothing finite to scale by: let imshow autoscale
    else:
        vmin, vmax = finite.min(), finite.max()

    for i in range(8):
        plt.subplot(421 + i)
        # subplot i sits at row i // 2, column i % 2; column 0 takes items 0-3, column 1 items 4-7
        plotResponseImg(combined[4 * (i % 2) + i // 2], vmin, vmax)

def doPlotAssymetry(assymetry_real, orto, assymetry_fake=None, rangeByReal=True):
    """Overlay normalised cluster-asymmetry histograms for Geant and, optionally, GAN.

    :param assymetry_real: Geant asymmetry values (red, labelled 'Geant').
    :param orto: True labels the axis 'Longitudinal', False 'Transverse'.
    :param assymetry_fake: optional GAN values (blue, labelled 'GAN'); when given, the
        L1/L2 distances between the two histograms are appended to the x label.
    :param rangeByReal: take the histogram range from the real sample only; otherwise
        use the min/max over both samples.
    """
    histRange, postfix = comonHistRange(assymetry_real, assymetry_fake, rangeByReal)
    sharedOpts = dict(bins=defaultBinning, range=histRange, alpha=0.3, density=True)
    realHeights = plt.hist(assymetry_real, color='red', label='Geant', **sharedOpts)[0]
    if assymetry_fake is not None:
        fakeHeights = plt.hist(assymetry_fake, color='blue', label='GAN', **sharedOpts)[0]
        postfix = postfix + '\n' + statsMsg(fakeHeights, realHeights)
    direction = 'Longitudinal' if orto else 'Transverse'
    plt.xlabel(direction + ' cluster asymmetry' + postfix)
    plt.legend(loc='best')

def doPlotAssymetryArr(assymetries, orto, rangeByReal=True):
    """Overlay normalised cluster-asymmetry histograms for several samples.

    All samples share one range (the min/max over every sample) and 120 bins. Sample i
    gets colour ``sixColors[i % len(sixColors)]``, so more than six samples no longer
    raise IndexError. No legend is drawn; callers add one if needed.

    :param assymetries: sequence of asymmetry arrays.
    :param orto: True labels the axis 'Longitudinal', False 'Transverse'.
    :param rangeByReal: unused (the range is always common); kept for signature parity
        with doPlotAssymetry.
    """
    histRange, postfix = comonHistRangeArr(assymetries)
    for i, values in enumerate(assymetries):
        colour = sixColors[i % len(sixColors)]  # cycle instead of running off the palette
        plt.hist(values, bins=120, range=histRange, color=colour, alpha=0.3, density=True)
    plt.xlabel(('Longitudinal' if orto else 'Transverse') + ' cluster asymmetry' + postfix)


def statsMsg(histFake, histReal):
    """Format the L1 and L2 distances between two histograms' bin heights as one line.

    The distances come from ``AF.l1norm`` / ``AF.l2norm`` and are called as (real, fake).
    """
    # TODO: add a chi-square statistic (AF.chiSquare) once it is clear how to compute it
    # for numpy/matplotlib histograms
    pieces = [
        'L1 dist.: {:10.4f} '.format(AF.l1norm(histReal, histFake)),
        'L2 dist.: {:10.4f} '.format(AF.l2norm(histReal, histFake)),
    ]
    return ' ; '.join(pieces)

def plotAssymWithNpHist(ecalAssym):  # currently unused, represents template for stats calc
    """Template: compute an asymmetry histogram with numpy, then draw it with matplotlib.

    Computing the histogram first exposes the bin heights (e.g. for statsMsg) while
    drawing the same picture as
    ``plt.hist(ecalAssym, bins=defaultBinning, range=[-1, 1], density=True)``.
    """
    # 1st step: normalised histogram over [-1, 1]
    assymHist, binz = np.histogram(ecalAssym, bins=defaultBinning, range=[-1, 1], density=True)
    # 2nd step: redraw it by putting one weighted sample at each bin's left edge.
    # bins must be the SAME edges: bins=len(binz) would build len(binz) new equal-width
    # bins (one more than computed) and misplace the bar heights.
    plt.hist(binz[:-1], bins=binz, weights=assymHist, color='red', alpha=0.3, density=True, label='Geant')

def doPlotShowerWidth(ecalWidth, orto, fakeWidth=None, rangeByReal=True):
    """Overlay normalised cluster-width histograms for Geant and, optionally, GAN.

    :param ecalWidth: Geant widths (red, labelled 'Geant').
    :param orto: True labels the axis 'Longitudinal', False 'Transverse'.
    :param fakeWidth: optional GAN widths (blue, labelled 'GAN'); when given, the L1/L2
        distances between the two histograms are appended to the x label.
    :param rangeByReal: take the range from the real sample only; otherwise use both.
    """
    histRange, postfix = comonHistRange(ecalWidth, fakeWidth, rangeByReal)
    histOpts = dict(bins=defaultBinning, range=histRange, density=True, alpha=0.3)

    realHeights = plt.hist(ecalWidth, color='red', label='Geant', **histOpts)[0]

    if fakeWidth is not None:
        fakeHeights = plt.hist(fakeWidth, color='blue', label='GAN', **histOpts)[0]
        postfix = '{}\n{}'.format(postfix, statsMsg(fakeHeights, realHeights))
    plt.legend(loc='best')
    axisName = 'Longitudinal' if orto else 'Transverse'
    plt.xlabel(axisName + ' cluster width [cm]' + postfix)
    plt.ylabel('Arbitrary units')

def doPlotShowerWidthArr(widths, orto, rangeByReal=True):
    """Overlay normalised cluster-width histograms for several Geant/GAN samples.

    Samples are expected in alternating order (Geant, GAN, Geant, GAN, ...). Pair k is
    tagged with the letter 'a', 'b', 'c', ... in the legend. All samples share one range
    (the min/max over every sample) and 120 bins. Colours cycle through ``sixColors``, so
    more than six samples no longer raise IndexError. For six samples the legend is
    identical to the previous hard-coded one.

    :param widths: sequence of width arrays.
    :param orto: True labels the axis 'Longitudinal', False 'Transverse'.
    :param rangeByReal: unused (the range is always common); kept for signature parity.
    """
    histRange, postfix = comonHistRangeArr(widths)
    labels = []
    for i, w in enumerate(widths):
        plt.hist(w, bins=120, range=histRange, density=True, alpha=0.3,
                 color=sixColors[i % len(sixColors)])
        # one label per drawn histogram, so the legend never has more labels than handles
        labels.append('{} ({})'.format('GAN' if i % 2 else 'Geant', chr(ord('a') + i // 2)))

    plt.legend(labels, loc='best', prop={'size': 12})
    plt.xlabel(('Longitudinal' if orto else 'Transverse') + ' cluster width [cm]' + postfix)
    plt.ylabel('Arbitrary units')

def doPlotSingleSparsity(sparsity, alpha, color='red'):
    """Plot the mean sparsity curve with a +/- one standard deviation band.

    :param sparsity: array with samples along axis 0; mean and std are taken over it.
    :param alpha: x-axis values (doPlotSparsity labels them log10(Threshold/GeV));
        not the transparency.
    :param color: colour of both the line and the band.
    """
    centre = np.mean(sparsity, axis=0)
    spread = np.std(sparsity, axis=0)
    lower, upper = centre - spread, centre + spread
    plt.plot(alpha, centre, color=color)
    plt.fill_between(alpha, lower, upper, color=color, alpha=0.3)


def doPlotSparsity(ecalSparsity, alpha, fakeSparsity=None):
    """Plot the fraction of cells above threshold for Geant (red) and, optionally, GAN (blue).

    :param ecalSparsity: Geant sparsity values, converted with np.array.
    :param alpha: threshold axis values (labelled log10(Threshold/GeV)).
    :param fakeSparsity: optional GAN sparsity values.
    """
    series = [(ecalSparsity, 'red', 'Geant')]
    if fakeSparsity is not None:
        series.append((fakeSparsity, 'blue', 'GAN'))
    for values, colour, _ in series:
        doPlotSingleSparsity(np.array(values), alpha, color=colour)

    plt.legend([name for _, _, name in series])
    plt.title('Sparsity')
    plt.xlabel('log10(Threshold/GeV)')
    plt.ylabel('Fraction of cells above threshold')


def plotEnergies(ecal, logScaled, fake=None, rangeByReal=True, size=30):
    """Histogram total deposited energy for Geant and, optionally, GAN (100 bins).

    :param ecal: Geant energies (red).
    :param logScaled: use a log y axis. The L1/L2 stats are then also computed on log10
        bin counts, with empty bins counted as 1 (log10 -> 0).
    :param fake: optional GAN energies (blue); adds L1/L2 stats to the x label.
    :param rangeByReal: take the range from the real sample only; otherwise use both.
    :param size: image side length, used only in the figure title.
    """
    plt.suptitle('Energy deposited in {}x{} '.format(size, size), fontsize=16)
    yPostfix = ', LogScale' if logScaled else ''
    histRange, xPostfix = comonHistRange(ecal, fake, rangeByReal)
    histOpts = dict(range=histRange, log=logScaled, alpha=0.3)

    realCounts = plt.hist(ecal, 100, color='red', **histOpts)[0]
    legend = ['Geant']

    if fake is not None:
        fakeCounts = plt.hist(fake, 100, color='blue', **histOpts)[0]
        if logScaled:
            # compare on the same log scale the plot uses; avoid log10(0)
            for counts in (fakeCounts, realCounts):
                counts[counts == 0] = 1
            fakeCounts, realCounts = np.log10(fakeCounts), np.log10(realCounts)
        xPostfix = xPostfix + '\n' + statsMsg(fakeCounts, realCounts)
        legend.append('GAN')

    plt.legend(legend)
    plt.xlabel('Energy ' + xPostfix)
    plt.ylabel('Arbitrary Units' + yPostfix)

def plotEnergiesArr(energies, logScaled, rangeByReal=True, size=30):
    """Overlay total-energy histograms (120 bins) for several samples on a shared range.

    Only the first 10000 entries of each sample are histogrammed, while the common range is
    computed over all entries. A single warning is printed when any sample is truncated
    (previously it was printed once per sample, unconditionally). Colours cycle through
    ``sixColors``. No legend or y label is set.

    :param energies: sequence of energy arrays.
    :param logScaled: use a log y axis.
    :param rangeByReal: unused (the range is always common); kept for signature parity.
    :param size: image side length, used only in the x label.
    """
    maxEntries = 10000
    commonRange, xPostfix = comonHistRangeArr(energies)

    if any(len(e) > maxEntries for e in energies):
        print('warning: hist for 10k first energy responses')

    for i, e in enumerate(energies):
        plt.hist(e[:maxEntries], 120, range=commonRange, log=logScaled,
                 color=sixColors[i % len(sixColors)], alpha=0.3)
    plt.xlabel('Energy deposited in {}x{} '.format(size, size) + xPostfix)

import seaborn as sns
import pandas as pd
def doPlotClusterShape(energies, centralEnergies, color, allPoints, range):
    """Plot centralEnergies / energies against energy as a seaborn line with a 99% CI band.

    :param energies: total energy per event (x axis).
    :param centralEnergies: per-event energies divided by ``energies`` to form y.
    :param color: line colour.
    :param allPoints: if True, x is the raw energies. Otherwise energies are cut into 19
        equal right-closed bins on (325, 525] and each point is placed at its bin midpoint.
        Points outside that interval (including exactly 325) are dropped.
    :param range: currently unused because the binning is hard-coded; kept for interface
        compatibility (the parameter name shadows the builtin).
    """
    energies = np.asarray(energies)
    ratio = np.true_divide(centralEnergies, energies)
    if allPoints:
        nrgBins = energies
    else:
        # NOTE(review): edges are hard-coded; log10-scaled energies all fall outside them
        nrgSpace = np.linspace(325, 525, 20)
        binnedNrg = pd.cut(energies, bins=nrgSpace)
        codes = np.asarray(binnedNrg.codes)
        # pd.cut marks out-of-range/NaN values with code -1; indexing with -1 would put them
        # silently at the LAST bin's midpoint, so drop them (and their ratios) instead
        inBin = codes >= 0
        nrgBins = np.asarray(binnedNrg.categories.mid)[codes[inBin]]
        ratio = ratio[inBin]
    # NOTE(review): newer seaborn deprecates `ci` in favour of errorbar=('ci', 99)
    sns.lineplot(x=nrgBins, y=ratio, color=color, ci=99)

def doPlotClShape(energies, ecalCentralEnergies, logScale, fakeEnergies=None, fakeCentralEnergies=None, allPoints=True):
    """Plot central/total energy ratio versus total deposit for Geant (red) and optionally GAN (blue).

    Each sample is sorted by total energy before plotting. With ``logScale`` the totals,
    central energies and common range are passed through log10 first.

    :param energies: Geant total energy per event.
    :param ecalCentralEnergies: matching Geant per-event "central" energies.
    :param logScale: apply log10 to the plotted quantities.
    :param fakeEnergies: optional GAN totals; ``fakeCentralEnergies`` must then be given too.
    :param fakeCentralEnergies: GAN central energies matching ``fakeEnergies``.
    :param allPoints: forwarded to doPlotClusterShape (raw vs. binned x).
    """
    def toScale(values):
        return log10IfRequired(values, logScale)

    # only the min/max range is forwarded; the label postfix is not used here
    commonRange, _ = comonHistRange(energies, fakeEnergies, False)

    def plotSample(totals, centrals, colour):
        totals = np.array(totals)
        order = np.argsort(totals)
        doPlotClusterShape(toScale(totals[order]), toScale(np.array(centrals)[order]),
                           colour, allPoints, toScale(commonRange))

    plotSample(energies, ecalCentralEnergies, 'red')
    legend = ['Geant']
    if fakeEnergies is not None:
        plotSample(fakeEnergies, fakeCentralEnergies, 'blue')
        legend.append('GAN')

    plt.legend(legend)
    plt.xlabel('Total Energy Deposit')
    plt.ylabel('Fraction of centered nrg by total')

def log10IfRequired(nparray, logScaled):
    """Return ``np.log10(nparray)`` when ``logScaled`` is truthy, else ``nparray`` itself (same object)."""
    if logScaled:
        return np.log10(nparray)
    return nparray


def comonHistRange(ecal, fake, rangeByReal):
    """Return ``([min, max], labelPostfix)`` for histogramming ``ecal`` and optionally ``fake``.

    With ``rangeByReal`` the range comes from ``ecal`` alone and the postfix is ''.
    Otherwise the range also covers ``fake`` (when given) and the postfix is
    ', common range'.
    """
    useCommon = not rangeByReal
    low, high = np.amin(ecal), np.amax(ecal)
    if useCommon and fake is not None:
        high = max(high, np.amax(fake))
        low = min(low, np.amin(fake))
    postfix = ', common range' if useCommon else ''
    return [low, high], postfix

def comonHistRangeArr(ecal):
    """Return ``([min, max], '')`` over every array in the sequence ``ecal``.

    An empty sequence yields ``[math.inf, -math.inf]``.
    """
    samples = list(ecal)  # iterate the input only once
    # builtin max/min fold left-to-right, matching a running max/min seeded with +/-inf
    high = max([-math.inf] + [np.amax(s) for s in samples])
    low = min([math.inf] + [np.amin(s) for s in samples])
    return [low, high], ''